{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import jit, grad, vmap\n",
    "from jax import random\n",
    "\n",
    "from functools import wraps, partial\n",
    "from jax import core\n",
    "from jax import lax\n",
    "from jax._src.util import safe_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "foo\n",
      "=====\n",
      "invars: [a]\n",
      "outvars: [b]\n",
      "constvars: []\n",
      "equation: [a, 1] add [b] {}\n",
      "\n",
      "jaxpr: { lambda  ; a.\n",
      "  let b = add a 1\n",
      "  in (b,) }\n",
      "\n",
      "bar\n",
      "=====\n",
      "invars: [a, b, c]\n",
      "outvars: [g, c]\n",
      "constvars: []\n",
      "equation: [a, c] dot_general [d] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': None}\n",
      "equation: [d, b] add [e] {}\n",
      "equation: [1.0] broadcast_in_dim [f] {'shape': (5,), 'broadcast_dimensions': ()}\n",
      "equation: [e, f] add [g] {}\n",
      "\n",
      "jaxpr: { lambda  ; a b c.\n",
      "  let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
      "                       precision=None\n",
      "                       preferred_element_type=None ] a c\n",
      "      e = add d b\n",
      "      f = broadcast_in_dim[ broadcast_dimensions=(  )\n",
      "                            shape=(5,) ] 1.0\n",
      "      g = add e f\n",
      "  in (g, c) }\n"
     ]
    }
   ],
   "source": [
    "def examine_jaxpr(closed_jaxpr):\n",
    "  \"\"\"Print the variables and equations of a closed jaxpr, then the jaxpr itself.\"\"\"\n",
    "  inner = closed_jaxpr.jaxpr\n",
    "  # One line per variable list, labelled with the attribute name.\n",
    "  for label in (\"invars\", \"outvars\", \"constvars\"):\n",
    "    print(label + \":\", getattr(inner, label))\n",
    "  # One line per equation: inputs, primitive, outputs, parameters.\n",
    "  for equation in inner.eqns:\n",
    "    print(\"equation:\", equation.invars, equation.primitive,\n",
    "          equation.outvars, equation.params)\n",
    "  print()\n",
    "  print(\"jaxpr:\", inner)\n",
    "\n",
    "def foo(x):\n",
    "  \"\"\"Return ``x + 1``.\"\"\"\n",
    "  return x + 1\n",
    "\n",
    "# Report header, followed by the jaxpr traced from `foo`.\n",
    "print(\"foo\", \"=====\", sep=\"\\n\")\n",
    "examine_jaxpr(jax.make_jaxpr(foo)(5))\n",
    "\n",
    "print()\n",
    "\n",
    "def bar(w, b, x):\n",
    "  \"\"\"Return ``dot(w, x) + b + ones(5)`` together with ``x`` itself.\"\"\"\n",
    "  shifted = jnp.dot(w, x) + b + jnp.ones(5)\n",
    "  return shifted, x\n",
    "\n",
    "# Report header, followed by the jaxpr traced from `bar`.\n",
    "print(\"bar\", \"=====\", sep=\"\\n\")\n",
    "examine_jaxpr(\n",
    "    jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "\n",
    "def slice_closed_jaxpr(closed_jaxpr, start=None, end=None):\n",
    "    \"\"\"Extract equations ``[start, end)`` of a ClosedJaxpr as a new ClosedJaxpr.\n",
    "\n",
    "    Inputs of the slice: every original input or earlier intermediate that a\n",
    "    sliced equation reads, plus any such variable still read by a later\n",
    "    equation or by the final outputs (those are forwarded unchanged).\n",
    "    Outputs of the slice: every variable read after the slice that is\n",
    "    available once the slice has run. Constants read by the sliced equations\n",
    "    become the constants of the slice.\n",
    "\n",
    "    Args:\n",
    "        closed_jaxpr: the ClosedJaxpr to slice.\n",
    "        start: index of the first equation to keep; ``None`` means 0.\n",
    "        end: index one past the last equation to keep; ``None`` means the\n",
    "            number of equations.\n",
    "\n",
    "    Returns:\n",
    "        A ClosedJaxpr holding only the selected equations.\n",
    "    \"\"\"\n",
    "    jaxpr = closed_jaxpr.jaxpr\n",
    "    invars = set(jaxpr.invars)\n",
    "    consts_dir = OrderedDict(zip(jaxpr.constvars, closed_jaxpr.consts))\n",
    "\n",
    "    # Variables defined by the equations before, inside and after the slice.\n",
    "    pred_intermediate_vars = set()\n",
    "    slice_intermediate_vars = set()\n",
    "    succ_intermediate_vars = set()\n",
    "\n",
    "    # OrderedDicts serve as insertion-ordered sets: membership tests are O(1)\n",
    "    # instead of a list scan, and first-seen order is preserved.\n",
    "    slice_consts_dir = OrderedDict()\n",
    "    slice_invars = OrderedDict()\n",
    "    slice_outvars = OrderedDict()\n",
    "    slice_eqns = []\n",
    "\n",
    "    def available_before_slice(var):\n",
    "        return (var in invars) or (var in pred_intermediate_vars)\n",
    "\n",
    "    def forward_to_successors(var):\n",
    "        # `var` is read after the slice, so the slice has to output it; if it\n",
    "        # predates the slice it has to be taken as an input as well.\n",
    "        if available_before_slice(var):\n",
    "            slice_invars.setdefault(var)\n",
    "            slice_outvars.setdefault(var)\n",
    "        elif var in slice_intermediate_vars:\n",
    "            slice_outvars.setdefault(var)\n",
    "        else:\n",
    "            assert (var in consts_dir) or (var in succ_intermediate_vars), (\n",
    "                \"variable is read before it is defined: %s\" % (var,))\n",
    "\n",
    "    start = 0 if start is None else start\n",
    "    end = len(jaxpr.eqns) if end is None else end\n",
    "\n",
    "    for index, eqn in enumerate(jaxpr.eqns):\n",
    "        if index < start:\n",
    "            pred_intermediate_vars.update(eqn.outvars)\n",
    "        elif index < end:\n",
    "            slice_eqns.append(eqn)\n",
    "            for var in eqn.invars:\n",
    "                # Literals are checked first: they carry their own value and\n",
    "                # are not looked up in any of the sets.\n",
    "                if isinstance(var, core.Literal):\n",
    "                    continue\n",
    "                if var in consts_dir:\n",
    "                    slice_consts_dir.setdefault(var, consts_dir[var])\n",
    "                elif available_before_slice(var):\n",
    "                    slice_invars.setdefault(var)\n",
    "                else:\n",
    "                    assert var in slice_intermediate_vars, (\n",
    "                        \"variable is read before it is defined: %s\" % (var,))\n",
    "            slice_intermediate_vars.update(eqn.outvars)\n",
    "        else:  # end <= index\n",
    "            for var in eqn.invars:\n",
    "                if not isinstance(var, core.Literal):\n",
    "                    forward_to_successors(var)\n",
    "            succ_intermediate_vars.update(eqn.outvars)\n",
    "\n",
    "    for var in jaxpr.outvars:\n",
    "        # A literal output needs nothing from the slice; skipping it also\n",
    "        # avoids set lookups on a Literal.\n",
    "        if not isinstance(var, core.Literal):\n",
    "            forward_to_successors(var)\n",
    "\n",
    "    slice_jaxpr = core.Jaxpr(list(slice_consts_dir), list(slice_invars),\n",
    "                             list(slice_outvars), slice_eqns)\n",
    "    return core.ClosedJaxpr(slice_jaxpr, list(slice_consts_dir.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ lambda  ; a b.\n",
       "  let c = broadcast_in_dim[ broadcast_dimensions=(  )\n",
       "                            shape=(5,) ] 1.0\n",
       "      d = sin c\n",
       "      e = tanh a\n",
       "      f = mul d e\n",
       "      g = sin f\n",
       "      h = cos g\n",
       "      i = exp h\n",
       "  in (i, b) }"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def f(x, z):\n",
    "    \"\"\"Run ``x`` through a chain of elementwise ops; ``z`` is returned as-is.\"\"\"\n",
    "    sin_of_ones = jnp.sin(jnp.ones_like(x))\n",
    "    scaled = sin_of_ones * jnp.tanh(x)\n",
    "    result = jnp.exp(jnp.cos(jnp.sin(scaled)))\n",
    "    return result, z\n",
    "\n",
    "# Trace `f` once and display the resulting jaxpr.\n",
    "closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5), jnp.ones(6))\n",
    "closed_jaxpr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ lambda  ; a b.\n",
       "  let c = broadcast_in_dim[ broadcast_dimensions=(  )\n",
       "                            shape=(5,) ] 1.0\n",
       "      d = sin c\n",
       "      e = tanh a\n",
       "      f = mul d e\n",
       "  in (f, b) }"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First stage: equations 0 through 3 of the traced jaxpr.\n",
    "closed_jaxpr_slice1 = slice_closed_jaxpr(closed_jaxpr, 0, 4)\n",
    "closed_jaxpr_slice1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ lambda  ; f b.\n",
       "  let g = sin f\n",
       "      h = cos g\n",
       "      i = exp h\n",
       "  in (i, b) }"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Second stage: equation 4 through the last equation.\n",
    "closed_jaxpr_slice2 = slice_closed_jaxpr(closed_jaxpr, 4, None)\n",
    "closed_jaxpr_slice2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),\n",
       " DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference result: evaluate the unsliced jaxpr.\n",
    "run_full = core.jaxpr_as_fun(closed_jaxpr)\n",
    "run_full(jnp.ones(5), jnp.ones(6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[DeviceArray([0.6408594, 0.6408594, 0.6408594, 0.6408594, 0.6408594], dtype=float32), DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),\n",
       " DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the two slices back to back: the outputs of the first feed the second.\n",
    "run_slice1 = core.jaxpr_as_fun(closed_jaxpr_slice1)\n",
    "run_slice2 = core.jaxpr_as_fun(closed_jaxpr_slice2)\n",
    "intermediate = run_slice1(jnp.ones(5), jnp.ones(6))\n",
    "print(intermediate)\n",
    "run_slice2(*intermediate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),\n",
       " DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same two-stage evaluation, with each slice wrapped in jit.\n",
    "jitted_slice1 = jit(core.jaxpr_as_fun(closed_jaxpr_slice1))\n",
    "jitted_slice2 = jit(core.jaxpr_as_fun(closed_jaxpr_slice2))\n",
    "intermediate = jitted_slice1(jnp.ones(5), jnp.ones(6))\n",
    "jitted_slice2(*intermediate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: merge with Lianmin's code\n",
    "# TODO: PyTree inputs\n",
    "# Q: How should lax.cond & lax.while_loop be handled?\n",
    "#    Ideally we should inline lax.cond & lax.while_loop\n",
    "# Q: How about the backward pass?\n",
    "# Q: How to slice a computation into different stages, given that a jaxpr is actually a graph?\n",
    "# Q: Why jaxpr? Try XLA instead\n",
    "# Forward & backward device assignment (very general)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ lambda  ; a b.\n",
       "  let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] a b\n",
       "      d = exp c\n",
       "  in (d,) }"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# @jax.jit\n",
    "def matmul(w, x):\n",
    "    \"\"\"Multiply ``w`` by ``x`` with the ``@`` operator.\"\"\"\n",
    "    product = w @ x\n",
    "    return product\n",
    "\n",
    "def f(w, x):\n",
    "    \"\"\"Return ``exp(matmul(w, x))``.\"\"\"\n",
    "    return jnp.exp(matmul(w, x))\n",
    "\n",
    "# Trace `f` on a 5x5 matrix and a length-5 vector and display the jaxpr.\n",
    "closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))\n",
    "closed_jaxpr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ lambda  ; a b.\n",
       "  let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] a b\n",
       "      d = exp c\n",
       "  in (d,) }"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Trace the same function again, this time with jit disabled.\n",
    "with jax.disable_jit():\n",
    "    trace_f = jax.make_jaxpr(f)\n",
    "    closed_jaxpr = trace_f(jnp.ones((5, 5)), jnp.ones(5))\n",
    "closed_jaxpr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import core\n",
    "from jax.lib import xla_client\n",
    "from jax.interpreters import xla, ad\n",
    "\n",
    "# Two marker primitives; each one returns a sequence of results.\n",
    "pipeline_start_p = core.Primitive(\"pipeline_start\")\n",
    "pipeline_end_p = core.Primitive(\"pipeline_end\")\n",
    "pipeline_start_p.multiple_results = True\n",
    "pipeline_end_p.multiple_results = True\n",
    "\n",
    "def mark_pipeline_start(*args, name):\n",
    "    \"\"\"Bind ``args`` to the ``pipeline_start`` primitive under the label ``name``.\"\"\"\n",
    "    marked = pipeline_start_p.bind(*args, name=name)\n",
    "    return marked\n",
    "\n",
    "def mark_pipeline_end(*args, name):\n",
    "    \"\"\"Bind ``args`` to the ``pipeline_end`` primitive under the label ``name``.\"\"\"\n",
    "    marked = pipeline_end_p.bind(*args, name=name)\n",
    "    return marked\n",
    "\n",
    "\n",
    "def pipeline_impl(*args, name):\n",
    "    \"\"\"Eager implementation of the markers: hand the arguments back unchanged.\n",
    "\n",
    "    With no arguments, a one-element tuple holding ``None`` is returned.\n",
    "    \"\"\"\n",
    "    return args if args else (None, )\n",
    "\n",
    "def pipeline_abstract_eval(*args, name):\n",
    "    \"\"\"Abstract evaluation of the markers: the inputs are returned unchanged.\n",
    "\n",
    "    With no arguments, a single ``core.abstract_unit`` is produced.\n",
    "    \"\"\"\n",
    "    if args:\n",
    "        return args\n",
    "    return (core.abstract_unit, )\n",
    "\n",
    "def pipeline_xla_translation(c, *args, name):\n",
    "    \"\"\"XLA lowering of the markers: pack the operands into a tuple.\n",
    "\n",
    "    With no operands, the tuple holds a single float32 zero constant.\n",
    "    \"\"\"\n",
    "    operands = args if args else (xla_client.ops.Constant(c, np.float32(0.0)), )\n",
    "    return xla_client.ops.Tuple(c, operands)\n",
    "\n",
    "def pipeline_start_value_and_jvp(arg_values, arg_tangents, name):\n",
    "    \"\"\"JVP rule for ``pipeline_start``: mark primals and tangents separately.\n",
    "\n",
    "    The tangent marker is named ``name`` prefixed with ``jvp_``.\n",
    "    \"\"\"\n",
    "    return (\n",
    "        mark_pipeline_start(*arg_values, name=name),\n",
    "        mark_pipeline_start(*arg_tangents, name=\"jvp_\" + name),\n",
    "    )\n",
    "    \n",
    "def pipeline_start_transpose(ct, *args, name):\n",
    "    \"\"\"Transpose rule for ``pipeline_start``: the cotangents pass through a\n",
    "    ``pipeline_end`` marker named ``name`` prefixed with ``vjp_``.\"\"\"\n",
    "    return mark_pipeline_end(*ct, name=\"vjp_\" + name)\n",
    "\n",
    "def pipeline_end_value_and_jvp(arg_values, arg_tangents, name):\n",
    "    \"\"\"JVP rule for ``pipeline_end``: mark primals and tangents separately.\n",
    "\n",
    "    The tangent marker is named ``name`` prefixed with ``jvp_``.\n",
    "    \"\"\"\n",
    "    return (\n",
    "        mark_pipeline_end(*arg_values, name=name),\n",
    "        mark_pipeline_end(*arg_tangents, name=\"jvp_\" + name),\n",
    "    )\n",
    "    \n",
    "def pipeline_end_transpose(ct, *args, name):\n",
    "    \"\"\"Transpose rule for ``pipeline_end``: the cotangents pass through a\n",
    "    ``pipeline_start`` marker named ``name`` prefixed with ``vjp_``.\"\"\"\n",
    "    return mark_pipeline_start(*ct, name=\"vjp_\" + name)\n",
    "\n",
    "    \n",
    "# Register the shared evaluation/lowering rules and the per-primitive\n",
    "# autodiff rules on both marker primitives.\n",
    "for _prim, _jvp_rule, _transpose_rule in (\n",
    "    (pipeline_start_p, pipeline_start_value_and_jvp, pipeline_start_transpose),\n",
    "    (pipeline_end_p, pipeline_end_value_and_jvp, pipeline_end_transpose),\n",
    "):\n",
    "    _prim.def_impl(pipeline_impl)\n",
    "    _prim.def_abstract_eval(pipeline_abstract_eval)\n",
    "    # The same lowering is installed for each listed backend.\n",
    "    for _backend in ('cpu', 'gpu', 'tpu'):\n",
    "        xla.backend_specific_translations[_backend][_prim] = pipeline_xla_translation\n",
    "    ad.primitive_jvps[_prim] = _jvp_rule\n",
    "    ad.primitive_transposes[_prim] = _transpose_rule\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ lambda  ; a b.\n",
       "  let c d = pipeline_start[ name=1 ] a b\n",
       "      e = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] c d\n",
       "      f = pipeline_end[ name=1 ] e\n",
       "      g = pipeline_start[ name=2 ] f\n",
       "      h = exp g\n",
       "      i = reduce_sum[ axes=(0,) ] h\n",
       "      _ = mul i 7.0\n",
       "      j = pipeline_end[ name=2 ] i\n",
       "  in (j,) }"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def f_original(w, x):\n",
    "    \"\"\"Return ``sum(exp(matmul(w, x)))`` without any pipeline markers.\"\"\"\n",
    "    total = jnp.sum(jnp.exp(matmul(w, x)))\n",
    "    y = 7 * total  # computed but never used\n",
    "    return total\n",
    "\n",
    "def f(w, x):\n",
    "    \"\"\"Compute ``sum(exp(matmul(w, x)))`` as two marked pipeline stages.\n",
    "\n",
    "    Stage \"1\" wraps the matmul; stage \"2\" wraps the exp and the sum.\n",
    "    \"\"\"\n",
    "    (w, x) = mark_pipeline_start(w, x, name=\"1\")\n",
    "    product = matmul(w, x)\n",
    "    (product,) = mark_pipeline_end(product, name=\"1\")\n",
    "    (stage2_in,) = mark_pipeline_start(product, name=\"2\")\n",
    "    total = jnp.sum(jnp.exp(stage2_in))\n",
    "    y = 7 * total  # computed but never used\n",
    "    (total,) = mark_pipeline_end(total, name=\"2\")\n",
    "    return total\n",
    "\n",
    "# Trace the marked function with jit disabled and display the jaxpr.\n",
    "with jax.disable_jit():\n",
    "    closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))\n",
    "closed_jaxpr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(742.0658, dtype=float32)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the marked function under jit.\n",
    "jitted_f = jax.jit(f)\n",
    "jitted_f(jnp.ones((5, 5)), jnp.ones(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ lambda  ; a b.\n",
       "  let c d = pipeline_start[ name=1 ] a b\n",
       "      e = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] c d\n",
       "      f = pipeline_end[ name=1 ] e\n",
       "      g = pipeline_start[ name=2 ] f\n",
       "      h = exp g\n",
       "      i = reduce_sum[ axes=(0,) ] h\n",
       "      _ = mul i 7.0\n",
       "      _ = pipeline_end[ name=2 ] i\n",
       "      j = pipeline_start[ name=vjp_jvp_2 ] 1.0\n",
       "      k = broadcast_in_dim[ broadcast_dimensions=(  )\n",
       "                            shape=(5,) ] j\n",
       "      l = mul k h\n",
       "      m = pipeline_end[ name=vjp_jvp_2 ] l\n",
       "      n = pipeline_start[ name=vjp_jvp_1 ] m\n",
       "      o = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] n c\n",
       "      p = dot_general[ dimension_numbers=(((), ()), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] n d\n",
       "      q r = pipeline_end[ name=vjp_jvp_1 ] p o\n",
       "  in (q, r) }"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Jaxpr of the gradient of the jitted `f` with respect to both arguments.\n",
    "with jax.disable_jit():\n",
    "    grad_of_f = jax.grad(jax.jit(f), argnums=[0, 1])\n",
    "    closed_jaxpr = jax.make_jaxpr(grad_of_f)(jnp.ones((5, 5)), jnp.ones(5))\n",
    "closed_jaxpr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[148.41316, 148.41316, 148.41316, 148.41316, 148.41316],\n",
       "              [148.41316, 148.41316, 148.41316, 148.41316, 148.41316],\n",
       "              [148.41316, 148.41316, 148.41316, 148.41316, 148.41316],\n",
       "              [148.41316, 148.41316, 148.41316, 148.41316, 148.41316],\n",
       "              [148.41316, 148.41316, 148.41316, 148.41316, 148.41316]],            dtype=float32),\n",
       " DeviceArray([742.0658, 742.0658, 742.0658, 742.0658, 742.0658], dtype=float32))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the gradient of `f` with respect to both arguments.\n",
    "jax.grad(f, argnums=[0, 1])(jnp.ones((5, 5)), jnp.ones(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ lambda  ; a b c d.\n",
       "  let e f = pipeline_start[ name=1 ] a b\n",
       "      g h = pipeline_start[ name=jvp_1 ] c d\n",
       "      i = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] e f\n",
       "      j = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] g f\n",
       "      k = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] e h\n",
       "      l = add_any j k\n",
       "      m = pipeline_end[ name=1 ] i\n",
       "      n = pipeline_end[ name=jvp_1 ] l\n",
       "      o = pipeline_start[ name=2 ] m\n",
       "      p = pipeline_start[ name=jvp_2 ] n\n",
       "      q = exp o\n",
       "      r = mul p q\n",
       "      s = reduce_sum[ axes=(0,) ] q\n",
       "      t = reduce_sum[ axes=(0,) ] r\n",
       "      _ = mul s 7.0\n",
       "      _ = mul t 7.0\n",
       "      u = pipeline_end[ name=2 ] s\n",
       "      v = pipeline_end[ name=jvp_2 ] t\n",
       "  in (u, v) }"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Jaxpr of the forward-mode derivative of `f`; primals and tangents are all ones.\n",
    "with jax.disable_jit():\n",
    "    jvp_of_f = partial(jax.jvp, f)\n",
    "    closed_jaxpr = jax.make_jaxpr(jvp_of_f)(\n",
    "        (jnp.ones((5, 5)), jnp.ones(5)),\n",
    "        (jnp.ones((5, 5)), jnp.ones(5)),\n",
    "    )\n",
    "closed_jaxpr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single marker primitive; whether it marks a start or an end is carried by\n",
    "# the `mark_type` parameter. The class is reached through the imported `core`\n",
    "# module: the bare name `Primitive` is not imported anywhere in this notebook.\n",
    "pipeline_p = core.Primitive('pipeline')\n",
    "pipeline_p.multiple_results = True\n",
    "\n",
    "def mark_pipeline(*args, name, mark_type):\n",
    "    \"\"\"Bind ``args`` to the ``pipeline`` primitive.\n",
    "\n",
    "    Args:\n",
    "        *args: values passed to the marker.\n",
    "        name: label stored on the marker.\n",
    "        mark_type: one of 'start', 'end', 'jvp_start', 'jvp_end'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``mark_type`` is not one of the values above.\n",
    "    \"\"\"\n",
    "    known_mark_types = ('start', 'end', 'jvp_start', 'jvp_end')\n",
    "    if mark_type in known_mark_types:\n",
    "        return pipeline_p.bind(*args, name=name, mark_type=mark_type)\n",
    "    raise ValueError('Unknown mark type: %s' % mark_type)\n",
    "\n",
    "def _pipeline_impl(*args, **kwargs):\n",
    "    \"\"\"Eager implementation of the marker: an identity on its arguments.\n",
    "\n",
    "    Keyword parameters are ignored. With no arguments, ``(None, )`` is returned.\n",
    "    \"\"\"\n",
    "    if not args:\n",
    "        return (None, )\n",
    "    return args\n",
    "\n",
    "def _pipeline_abstract_eval(*args, **kwargs):\n",
    "    \"\"\"Abstract evaluation of the marker: the inputs are returned unchanged.\n",
    "\n",
    "    Keyword parameters are ignored. With no arguments, a single\n",
    "    ``core.abstract_unit`` is produced.\n",
    "    \"\"\"\n",
    "    # `abstract_unit` is reached through the imported `core` module; the bare\n",
    "    # name is not imported anywhere in this notebook.\n",
    "    return args if len(args) > 0 else (core.abstract_unit, )\n",
    "\n",
    "def _pipeline_xla_translation(c, *args, **kwargs):\n",
    "    \"\"\"XLA lowering of the marker: pack the operands into a tuple.\n",
    "\n",
    "    Keyword parameters are ignored. With no operands, the tuple holds a single\n",
    "    float32 zero constant.\n",
    "    \"\"\"\n",
    "    # The XLA client module is imported as `xla_client`; no `xc` alias is\n",
    "    # defined anywhere in this notebook.\n",
    "    if len(args) > 0:\n",
    "        return xla_client.ops.Tuple(c, args)\n",
    "    return xla_client.ops.Tuple(c, (xla_client.ops.Constant(c, np.float32(0.0)), ))\n",
    "\n",
    "def _pipeline_value_and_jvp(arg_values, arg_tangents, name, mark_type):\n",
    "    \"\"\"JVP rule: mark the primals with ``mark_type`` and the tangents with the\n",
    "    matching ``jvp_start`` / ``jvp_end`` type, both under the same ``name``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``mark_type`` is neither a start nor an end type.\n",
    "    \"\"\"\n",
    "    primal_outs = mark_pipeline(*arg_values, name=name, mark_type=mark_type)\n",
    "    # TODO(zhuohan): Check the semantics here works for higher order gradients.\n",
    "    is_start = mark_type in (\"start\", \"jvp_start\")\n",
    "    if not is_start and mark_type not in (\"end\", \"jvp_end\"):\n",
    "        raise ValueError(\"Invalid mark_type\")\n",
    "    tangent_mark_type = \"jvp_start\" if is_start else \"jvp_end\"\n",
    "    tangent_outs = mark_pipeline(*arg_tangents, name=name, mark_type=tangent_mark_type)\n",
    "    return primal_outs, tangent_outs\n",
    "\n",
    "def _pipeline_transpose(ct, *args, name, mark_type):\n",
    "    \"\"\"Transpose rule: pass the cotangents through a marker of the opposite\n",
    "    kind (start types become ``end``, end types become ``start``) under the\n",
    "    same ``name``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``mark_type`` is neither a start nor an end type.\n",
    "    \"\"\"\n",
    "    # TODO(zhuohan): Check the semantics here works for higher order gradients.\n",
    "    is_start = mark_type in (\"start\", \"jvp_start\")\n",
    "    if not is_start and mark_type not in (\"end\", \"jvp_end\"):\n",
    "        raise ValueError(\"Invalid mark_type\")\n",
    "    transposed_mark_type = \"end\" if is_start else \"start\"\n",
    "    return mark_pipeline(*ct, name=name, mark_type=transposed_mark_type)\n",
    "\n",
    "# Eager and abstract evaluation rules.\n",
    "pipeline_p.def_abstract_eval(_pipeline_abstract_eval)\n",
    "pipeline_p.def_impl(_pipeline_impl)\n",
    "# Autodiff rules: forward mode and transposition.\n",
    "ad.primitive_transposes[pipeline_p] = _pipeline_transpose\n",
    "ad.primitive_jvps[pipeline_p] = _pipeline_value_and_jvp\n",
    "# XLA lowering, registered in `xla.translations`.\n",
    "xla.translations[pipeline_p] = _pipeline_xla_translation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ lambda  ; a b.\n",
       "  let c d = pipeline[ mark_type=start\n",
       "                      name=1 ] a b\n",
       "      e = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] c d\n",
       "      f = pipeline[ mark_type=end\n",
       "                    name=1 ] e\n",
       "      g = pipeline[ mark_type=start\n",
       "                    name=2 ] f\n",
       "      h = exp g\n",
       "      i = reduce_sum[ axes=(0,) ] h\n",
       "      _ = mul i 7.0\n",
       "      j = pipeline[ mark_type=end\n",
       "                    name=2 ] i\n",
       "  in (j,) }"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def f(w, x):\n",
    "    \"\"\"Compute ``sum(exp(matmul(w, x)))`` as two stages marked with ``mark_pipeline``.\n",
    "\n",
    "    Stage \"1\" wraps the matmul; stage \"2\" wraps the exp and the sum.\n",
    "    \"\"\"\n",
    "    (w, x) = mark_pipeline(w, x, name=\"1\", mark_type='start')\n",
    "    product = matmul(w, x)\n",
    "    (product,) = mark_pipeline(product, name=\"1\", mark_type='end')\n",
    "    (stage2_in,) = mark_pipeline(product, name=\"2\", mark_type='start')\n",
    "    total = jnp.sum(jnp.exp(stage2_in))\n",
    "    y = 7 * total  # computed but never used\n",
    "    (total,) = mark_pipeline(total, name=\"2\", mark_type='end')\n",
    "    return total\n",
    "\n",
    "# Trace the marked function with jit disabled and display the jaxpr.\n",
    "with jax.disable_jit():\n",
    "    closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))\n",
    "closed_jaxpr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ lambda  ; a b.\n",
       "  let c d = pipeline[ mark_type=start\n",
       "                      name=1 ] a b\n",
       "      e = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] c d\n",
       "      f = pipeline[ mark_type=end\n",
       "                    name=1 ] e\n",
       "      g = pipeline[ mark_type=start\n",
       "                    name=2 ] f\n",
       "      h = exp g\n",
       "      i = reduce_sum[ axes=(0,) ] h\n",
       "      _ = mul i 7.0\n",
       "      _ = pipeline[ mark_type=end\n",
       "                    name=2 ] i\n",
       "      j = pipeline[ mark_type=start\n",
       "                    name=2 ] 1.0\n",
       "      k = broadcast_in_dim[ broadcast_dimensions=(  )\n",
       "                            shape=(5,) ] j\n",
       "      l = mul k h\n",
       "      m = pipeline[ mark_type=end\n",
       "                    name=2 ] l\n",
       "      n = pipeline[ mark_type=start\n",
       "                    name=1 ] m\n",
       "      o = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] n c\n",
       "      p = dot_general[ dimension_numbers=(((), ()), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] n d\n",
       "      q r = pipeline[ mark_type=end\n",
       "                      name=1 ] p o\n",
       "  in (q, r) }"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Jaxpr of the gradient of the jitted `f` with respect to both arguments.\n",
    "with jax.disable_jit():\n",
    "    grad_of_f = jax.grad(jax.jit(f), argnums=[0, 1])\n",
    "    closed_jaxpr = jax.make_jaxpr(grad_of_f)(jnp.ones((5, 5)), jnp.ones(5))\n",
    "closed_jaxpr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ lambda  ; a b c d.\n",
       "  let e f = pipeline[ mark_type=start\n",
       "                      name=1 ] a b\n",
       "      g h = pipeline[ mark_type=jvp_start\n",
       "                      name=1 ] c d\n",
       "      i = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] e f\n",
       "      j = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] g f\n",
       "      k = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
       "                       precision=None\n",
       "                       preferred_element_type=None ] e h\n",
       "      l = add_any j k\n",
       "      m = pipeline[ mark_type=end\n",
       "                    name=1 ] i\n",
       "      n = pipeline[ mark_type=jvp_end\n",
       "                    name=1 ] l\n",
       "      o = pipeline[ mark_type=start\n",
       "                    name=2 ] m\n",
       "      p = pipeline[ mark_type=jvp_start\n",
       "                    name=2 ] n\n",
       "      q = exp o\n",
       "      r = mul p q\n",
       "      s = reduce_sum[ axes=(0,) ] q\n",
       "      t = reduce_sum[ axes=(0,) ] r\n",
       "      _ = mul s 7.0\n",
       "      _ = mul t 7.0\n",
       "      u = pipeline[ mark_type=end\n",
       "                    name=2 ] s\n",
       "      v = pipeline[ mark_type=jvp_end\n",
       "                    name=2 ] t\n",
       "  in (u, v) }"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Jaxpr of the forward-mode derivative of `f`; primals and tangents are all ones.\n",
    "with jax.disable_jit():\n",
    "    jvp_of_f = partial(jax.jvp, f)\n",
    "    closed_jaxpr = jax.make_jaxpr(jvp_of_f)(\n",
    "        (jnp.ones((5, 5)), jnp.ones(5)),\n",
    "        (jnp.ones((5, 5)), jnp.ones(5)),\n",
    "    )\n",
    "closed_jaxpr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_anaconda3)",
   "language": "python",
   "name": "conda_anaconda3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
